% pull_power_timecourses_wholePostINT
% Focused analyses of alpha effects. Unlike the analyses in
% pull_power_timecourses, which are restricted to the period during the INT
% task (for aud WM) and after the INT task (for vis WM), these analyses
% examine the entire time window between INT task onset and the end of WM
% retention for both WM modalities.

% Analysis parameters
badSs = [19];           % subject indices removed from indSs/indSs_diff
doStat = true;          % set to false to skip CBPT

int_st = 1.5;           % (sec) time after start of INT win that INT stim starts
int_dur = 1;            % (sec) approximate duration of the INT task
range_alpha = 8:12;     % Hz; candidate frequencies searched for each subject's IAF
range_alpha_high = 11:15;
range_theta = 4:7;
alpha_chans_enc = {'PO8', 'PO4', 'O2', 'Pz', 'P2'}; % central-occipital and lateralized
alpha_chans_durINT = {'PO8', 'PO4', 'O2', 'Pz', 'P2'};
% Channel labels are matched case-sensitively (ismember/strcmp), so they must
% match the data's labels exactly: 'AFz', not 'Afz'.
theta_chans = {'AFz', 'Fpz', 'Fp1', 'Fp2'};

% Machine-specific data/result directories and toolbox paths, keyed on the
% Windows computer name. Unknown machines stop the script immediately.
compname = getenv('computername');
switch compname
    case 'DESKTOP-BHR0AU7'
        rawdir = 'E:\ANL\Experiments\DATA\VAST\RAW\';
        datadir = 'E:\ANL\Experiments\DATA\VAST\PREPROCESSED\TFR\';
        resultsdir = 'E:\ANL\Experiments\RESULTS\VAST\EEG\';
        data4Rdir = 'E:\ANL\Experiments\RESULTS\VAST\DATAforR\';
        addpath('C:\Users\jtjus\Documents\MATLAB\fieldtrip-20180512\fieldtrip-20180512');
        addpath(genpath('E:\MATLAB\CSDtoolbox'));
    case 'MSI'
        rawdir = 'E:\ANL\RAW_DATA\VAST\';
        datadir = 'E:\ANL\PREPROCESSED_DATA\VAST\TFR\';
        resultsdir = 'E:\ANL\RESULTS\VAST\EEG\';
        data4Rdir = 'E:\ANL\RESULTS\VAST\DATAforR\';
        addpath('D:\MATLAB\fieldtrip-20180809\fieldtrip-20180809');
        addpath(genpath('D:\MATLAB\CSDtoolbox'));
    otherwise
        error('need to set up paths for this machine.')
end

%% Load and format individual subjects' TFR data
% Excluded subjects (badSs) are removed before anything is counted.
loaded = load(strcat(resultsdir, 'indSs_TFR.mat'));
indSs = loaded.indSs;
indSs_diff = loaded.indSs_diff;
clear loaded
indSs(badSs) = [];
indSs_diff(badSs) = [];

% Layout, as indexed throughout this script: indSs{s}(1,:) holds condition
% names and indSs{s}(2,:) per-condition cells; inside each of those, row 1
% holds window names and row 2 the power arrays.
Nss = length(indSs);
SEM_denom = sqrt(Nss);
firstSs = indSs{1};
cond_names = firstSs(1,:);
Nconds = size(firstSs, 2);
win_names = firstSs{2,1}(1,:);
Nwins = size(firstSs{2,1}, 2);
clear firstSs
% Difference TFRs omit 'none' conditions and 'enc' windows
cond_names_diff = cond_names(~contains(cond_names, 'none'));
Nconds_diff = length(cond_names_diff);
win_names_diff = win_names(~contains(win_names, 'enc'));
Nwins_diff = length(win_names_diff);

% Make a combined FT struct for putting in power time courses like ERPs:
% take the channel, frequency and time axes from a dummy TFR, and keep a
% dummy FieldTrip struct (without its cfg history) as a container template.
dummyAxes = load('dummy_indSs_TFR_80Hz.mat');
labels_ch = dummyAxes.dummy_indSs_TFR_80Hz.label;
labels_f = dummyAxes.dummy_indSs_TFR_80Hz.freq;
general_time = dummyAxes.dummy_indSs_TFR_80Hz.time;
clear dummyAxes
BL_len = 1; % baseline length, in the units of .time (presumably sec)
% A time axis that starts at 0 is shifted so the baseline lies at negative times
if general_time(1) == 0
    general_time = general_time - BL_len;
end
dummyFull = load('dummy_indSs_full.mat');
fakeTFR = dummyFull.dummy_indSs_full;
clear dummyFull
if isfield(fakeTFR.cfg, "previous")
    fakeTFR.cfg = rmfield(fakeTFR.cfg, "previous");
end


%% Truncate the power spectra
% For each window, find the shortest time dimension (dim 3) across all
% subjects and conditions, then trim every longer array to that length:
% 'int' windows keep their first samples, every other window keeps its last.
min_lens = inf(1, Nwins);
for s = 1:Nss
    for c = 1:Nconds
        for w = 1:Nwins
            min_lens(w) = min(min_lens(w), size(indSs{s}{2,c}{2,w}, 3));
        end
    end
end

for w = 1:Nwins
    keep = min_lens(w);
    for c = 1:Nconds
        for s = 1:Nss
            nt = size(indSs{s}{2,c}{2,w}, 3);
            if nt <= keep
                continue % already at the minimum length
            end
            if strcmp(indSs{s}{2,c}{1,w}, 'int')
                idx = 1:keep;            % drop samples from the end
            else
                idx = (nt - keep + 1):nt; % drop samples from the beginning
            end
            indSs{s}{2,c}{2,w} = indSs{s}{2,c}{2,w}(:,:,idx);
        end
    end
end

% Step 3: Same truncation for the difference TFRs (sized by Nconds_diff and
% Nwins_diff): per-window minimum time length over subjects and conditions,
% 'int' windows trimmed at the end, all others trimmed at the beginning.
min_lens_diff = inf(1, Nwins_diff);
for s = 1:Nss
    for c = 1:Nconds_diff
        for w = 1:Nwins_diff
            min_lens_diff(w) = min(min_lens_diff(w), size(indSs_diff{s}{2,c}{2,w}, 3));
        end
    end
end

for w = 1:Nwins_diff
    keep = min_lens_diff(w);
    for c = 1:Nconds_diff
        for s = 1:Nss
            nt = size(indSs_diff{s}{2,c}{2,w}, 3);
            if nt <= keep
                continue % already at the minimum length
            end
            if strcmp(indSs_diff{s}{2,c}{1,w}, 'int')
                idx = 1:keep;            % drop samples from the end
            else
                idx = (nt - keep + 1):nt; % drop samples from the beginning
            end
            indSs_diff{s}{2,c}{2,w} = indSs_diff{s}{2,c}{2,w}(:,:,idx);
        end
    end
end

% Time bases for the truncated windows. Windows are located by name (as
% elsewhere in this script) rather than assumed to sit at positions 1 and 2.
win_enc = find(contains(win_names, 'enc'), 1);
win_int = find(strcmp(win_names, 'int'), 1);
if isempty(win_enc) || isempty(win_int)
    error('Expected an ''enc'' window and an ''int'' window in win_names.')
end
% NOTE(review): both bases index general_time from its first sample. That
% matches 'int' windows (truncated at the end), but 'enc' windows are
% truncated at the beginning -- confirm tb_enc is aligned as intended.
tb_enc = general_time(1:min_lens(win_enc));
tb_int = general_time(1:min_lens(win_int));
tb_noBL = general_time(general_time >= BL_len);
find_tb_durINT = tb_int >= int_st & tb_int < int_st + int_dur; % during 1 sec INT task
find_tb_preINT = tb_int >= 0 & tb_int < int_st; % before INT task onset, excluding baseline
find_tb_postINT = tb_int >= int_st + int_dur; % after INT task offset
find_tb_wholeINT = tb_int > 0; % entire retention (baseline already negative)
% KEY TIME RANGES FOR THIS SCRIPT
find_tb_beforeOnset = tb_int < int_st; % everything before INT onset, incl. the negative-time baseline; used for finding IAF
find_tb_afterOnset = tb_int >= int_st; % retention starting at INT onset


%% Alpha timecourse testing

% Get the individual alpha frequencies (IAFs) for each subject
% NEW: for aud, MAX in the pre-INT onset phase
%      for vis, MIN in the pre-INT onset phase
% The per-condition peak/trough frequency is then averaged across conditions
% and rounded, giving one IAF per subject (Nss x 1).
IAFs_retention = zeros(Nss, Nconds);
w = strcmp(indSs{1}{2,1}(1,:), 'int'); % logical index of the 'int' window
ch = ismember(labels_ch, alpha_chans_durINT);
f = ismember(labels_f, range_alpha);
% Map argmax/argmin indices through the frequencies actually selected from
% the data's frequency axis; indexing range_alpha directly is wrong whenever
% some range_alpha values are absent from labels_f.
alpha_freqs = labels_f(f);
if isempty(alpha_freqs)
    error('None of range_alpha is present in the TFR frequency axis.')
end
for s = 1:Nss
    for c = 1:Nconds
        x = indSs{s}{2,c}{2,w}(ch, f, find_tb_beforeOnset);
        % Average over channels (dim 1) and time (dim 3) explicitly, so a
        % single selected channel or time sample cannot shift dimensions
        x = mean(mean(x, 1), 3); % 1 x nFreq
        if contains(cond_names{c}, {'as_','at_'}) % aud WM
            [~, y] = max(x);
        else % alpha suppressed for vis, take min
            [~, y] = min(x);
        end
        IAFs_retention(s,c) = alpha_freqs(y);
    end
end
IAFs_retention = round(mean(IAFs_retention, 2)); % average across conditions

% Pull the time courses at each subject's IAF, from INT onset onward.
% Each entry is a copy of fakeTFR with .avg (channels x time), zero-filled
% .var/.dof of the same size, and the matching .time vector.
alpha_afterOnset = cell(Nss, Nconds);
for s = 1:Nss
    % Take at IAF only. (Alternative: one Hz above and below IAF, i.e.
    % ismember(labels_f, IAFs_retention(s)-1:IAFs_retention(s)+1).)
    iaf_mask = ismember(labels_f, IAFs_retention(s));
    for c = 1:Nconds
        tc = squeeze(mean(indSs{s}{2,c}{2,w}(:, iaf_mask, find_tb_afterOnset), 2)); % average across (near) IAFs
        entry = fakeTFR;
        entry.avg = tc;
        entry.var = zeros(size(tc));
        entry.dof = zeros(size(tc));
        entry.time = tb_int(find_tb_afterOnset);
        alpha_afterOnset{s, c} = entry;
    end
end
% Alpha by WM modality and INT task, collapsed across WM domain.
% Each row of domCol_pairs names the two conditions (t and s WM domains)
% averaged into one column; columns are a_none, a_at, a_as, v_none, v_at,
% v_as (3 INT conditions x aud|vis WM).
domCol_pairs = {'at_none', 'as_none';
                'at_at',   'as_at';
                'at_as',   'as_as';
                'vt_none', 'vs_none';
                'vt_at',   'vs_at';
                'vt_as',   'vs_as'};
Ncols_domCol = size(domCol_pairs, 1);
alpha_afterOnset_domCol = cell(Nss, Ncols_domCol);
w = strcmp(indSs{1}{2,1}(1,:), 'int');
for s = 1:Nss
    % One Hz above and below IAF
%     f_iaf = ismember(labels_f, IAFs_retention(s)-1:IAFs_retention(s)+1);
    % Take at IAF only
    f_iaf = ismember(labels_f, IAFs_retention(s));
    for k = 1:Ncols_domCol
        x_dom = cell(1, 2);
        for d = 1:2
            cidx = find(strcmp(indSs{s}(1,:), domCol_pairs{k, d}));
            if numel(cidx) ~= 1
                error('Subject index %d: expected exactly one condition named ''%s''.', ...
                    s, domCol_pairs{k, d});
            end
            x_dom{d} = squeeze(mean(indSs{s}{2,cidx}{2,w}(:, f_iaf, find_tb_afterOnset), 2)); % average across (near) IAFs
        end
        x = mean(cat(3, x_dom{:}), 3); % average across the two WM domains
        alpha_afterOnset_domCol{s,k} = fakeTFR;
        alpha_afterOnset_domCol{s,k}.avg = x;
        [alpha_afterOnset_domCol{s,k}.var, alpha_afterOnset_domCol{s,k}.dof] = ...
            deal(zeros(size(x)));
        alpha_afterOnset_domCol{s,k}.time = tb_int(find_tb_afterOnset);
    end
end

% Compute mean time courses -- raw TFRs.
% Row 1 holds condition names, row 2 the grand average across subjects.
% alpha_afterOnset{:, c} expands to one input argument per subject, so no
% eval-built call string is needed (and Nss == 1 no longer passes subject 1
% twice).
alpha_afterOnset_mean = [cond_names; cell(1, Nconds)];
ftcfg = [];
ftcfg.parameter = 'avg';
for c = 1:Nconds
    alpha_afterOnset_mean{2, c} = ft_timelockgrandaverage(ftcfg, alpha_afterOnset{:, c});
    if isfield(alpha_afterOnset_mean{2, c}.cfg, "previous")
        alpha_afterOnset_mean{2, c}.cfg = rmfield(alpha_afterOnset_mean{2, c}.cfg, "previous");
    end
end
% Compute mean time courses -- raw TFRs collapsed across domain
alpha_afterOnset_domCol_mean = [{'a_none','a_at','a_as','v_none','v_at','v_as'}; cell(1, 6)];
for c = 1:size(alpha_afterOnset_domCol_mean, 2)
    alpha_afterOnset_domCol_mean{2, c} = ft_timelockgrandaverage(ftcfg, alpha_afterOnset_domCol{:, c});
    if isfield(alpha_afterOnset_domCol_mean{2, c}.cfg, "previous")
        alpha_afterOnset_domCol_mean{2, c}.cfg = rmfield(alpha_afterOnset_domCol_mean{2, c}.cfg, "previous");
    end
end

if doStat == true
    % CBPT: paired cluster-based permutation tests between INT conditions
    % within each WM modality, on the domain-collapsed alpha time courses.
    statcfg = [];
    % ...neighbor structure (distance-based, from the biosemi64 layout)
    neighbcfg = [];
    neighbcfg.method = 'distance';
    neighbcfg.layout = 'biosemi64.lay';
    neighbcfg.channel = 'all';
    neighbors = ft_prepare_neighbours(neighbcfg, alpha_afterOnset{1, 1});
    statcfg.neighbours = neighbors;
    % ...design matrix: row 1 = subject (unit variable), row 2 = condition.
    %    Built in one assignment so a 'design' variable left in the
    %    workspace by an earlier run cannot cause a size mismatch.
    design = [repmat(1:Nss, 1, 2); ones(1, Nss), repmat(2, 1, Nss)];
    statcfg.design = design;
    statcfg.uvar = 1;
    statcfg.ivar = 2;
    % ...general parameters
    statcfg.method = 'montecarlo';
    statcfg.statistic = 'depsamplesT'; % paired tests
    statcfg.tail = 0; % 0 = two-sided permutation test
    statcfg.clustertail = 0; % 0 = two-sided cluster significance test
    statcfg.correctm = 'cluster';
    statcfg.correcttail = 'prob'; % adjust p-values if tests are two-sided
    statcfg.clusteralpha = 0.05; % threshold for defining clusters
    statcfg.clusterstatistic = 'maxsum'; % take the largest cluster
    statcfg.minnbchan = 2;
    statcfg.alpha = 0.05; % threshold for significance of the permutation test
    statcfg.numrandomization = 2000;
    % Test from INT onset to the end of the time course; all TFRs are
    % truncated the same, so one subject's time vector suffices.
    statcfg.latency = [int_st, alpha_afterOnset_domCol{1,1}.time(end)];

    % % Auditory WM, collapsed across domain
    STAT_alpha_afterOnset.aAT_aNone = ft_timelockstatistics(statcfg, alpha_afterOnset_domCol{:,2}, alpha_afterOnset_domCol{:,1});
    STAT_alpha_afterOnset.aAS_aNone = ft_timelockstatistics(statcfg, alpha_afterOnset_domCol{:,3}, alpha_afterOnset_domCol{:,1});
    STAT_alpha_afterOnset.aAT_aAS = ft_timelockstatistics(statcfg, alpha_afterOnset_domCol{:,2}, alpha_afterOnset_domCol{:,3});
    % % Visual WM, collapsed across domain
    STAT_alpha_afterOnset.vAT_vNone = ft_timelockstatistics(statcfg, alpha_afterOnset_domCol{:,5}, alpha_afterOnset_domCol{:,4});
    STAT_alpha_afterOnset.vAS_vNone = ft_timelockstatistics(statcfg, alpha_afterOnset_domCol{:,6}, alpha_afterOnset_domCol{:,4});
    STAT_alpha_afterOnset.vAT_vAS = ft_timelockstatistics(statcfg, alpha_afterOnset_domCol{:,5}, alpha_afterOnset_domCol{:,6});
end


%% Remove "previous" fields from the STATs

if doStat == true
    stat_fields = fieldnames(STAT_alpha_afterOnset);
    for i = 1:numel(stat_fields)
        fld = stat_fields{i};
        if isfield(STAT_alpha_afterOnset.(fld).cfg, 'previous')
            STAT_alpha_afterOnset.(fld).cfg = rmfield(STAT_alpha_afterOnset.(fld).cfg, 'previous');
        end
    end
end


%% Save
% Only variables created in this script are saved, and to this script's own
% file rather than freq_band_timecourses.mat, so no other results are
% overwritten. STATs are included only when CBPT was run.
save_vars = {'tb_int', 'find_tb_beforeOnset', 'find_tb_afterOnset', ...
    'IAFs_retention', 'alpha_afterOnset', 'alpha_afterOnset_mean', ...
    'alpha_afterOnset_domCol', 'alpha_afterOnset_domCol_mean'};
if doStat == true
    save_vars{end+1} = 'STAT_alpha_afterOnset';
end
save(strcat(resultsdir, 'freq_band_timecourses_wholePostINT.mat'), save_vars{:}, '-v7.3')


%% Construct simplified variables for parametric statistical testing in R

% Alpha time courses at Pz from INT onset to the end of retention: the
% across-subject mean and SEM per condition, and per domain-collapsed column.
% (This script computes no theta time courses, so none are exported.)
ch_a = strcmp(alpha_afterOnset_mean{2,1}.label, 'Pz'); % Pz for alpha
if ~any(ch_a)
    error('Channel Pz not found in the alpha time courses.')
end
tpts_alpha = size(alpha_afterOnset{1,1}.avg, 2); % already truncated to min
Ncols_domCol = size(alpha_afterOnset_domCol, 2);
domCol_names = alpha_afterOnset_domCol_mean(1,:);
[alphapow_mean, alphapow_sem] = deal(zeros(Nconds, tpts_alpha));
[alphapow_mean_domCollapsed, alphapow_sem_domCollapsed] = deal(zeros(Ncols_domCol, tpts_alpha));
% collect data (subjects x time), then summarize along subjects (dim 1);
% dim is explicit so a single subject does not collapse std to a scalar
for c = 1:Nconds
    x_a = cell2mat(cellfun(@(tl) tl.avg(ch_a,:), alpha_afterOnset(:, c), 'UniformOutput', false));
    alphapow_mean(c,:) = mean(x_a, 1);
    alphapow_sem(c,:) = std(x_a, 0, 1) ./ sqrt(Nss);
end
for c = 1:Ncols_domCol % separate loop for domain-collapsed alpha data
    x_a = cell2mat(cellfun(@(tl) tl.avg(ch_a,:), alpha_afterOnset_domCol(:, c), 'UniformOutput', false));
    alphapow_mean_domCollapsed(c,:) = mean(x_a, 1);
    alphapow_sem_domCollapsed(c,:) = std(x_a, 0, 1) ./ sqrt(Nss);
end
tb_alpha = tb_int(find_tb_afterOnset) + 0.5; % 0.5 offset in trialfun
save(strcat(data4Rdir, 'simple_timecourses_afterOnset.mat'), 'cond_names', 'domCol_names',...
    'alphapow_mean', 'alphapow_sem', 'alphapow_mean_domCollapsed',...
    'alphapow_sem_domCollapsed', 'tb_alpha')



